
import inspect
import sys
from typing import Optional

import torch

from .lenet5 import Lenet5
from .vgg import VGG, VGG_types
from .alexnet import AlexNet
from ..wrapper.general_wrapper import DLRTNetwork


# Per-dataset network dimensions: dataset name -> (in_channels, out_classes).
# Only needed by models whose constructors take these sizes (vgg16, alexnet).
_DATASET_SPECS = {
    "cifar10": (3, 10),
}


def _load_full_model(path, device):
    """Unpickle a whole model object saved with ``torch.save(model, path)``.

    Tensors are mapped directly onto ``device``, so a checkpoint written on a
    GPU can be opened on a CPU-only machine. Newer PyTorch releases default
    ``torch.load`` to ``weights_only=True``, which refuses pickled objects such
    as ``nn.Module`` instances. The full-unpickle behaviour is therefore
    requested explicitly whenever the installed version knows the keyword.

    SECURITY: this unpickles the file, which can execute arbitrary code.
    Only load checkpoints from trusted sources.
    """
    load_kwargs = {"map_location": device}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    return torch.load(path, **load_kwargs)


def choose_model(model_name: str, baseline: bool, tucker: bool, mat_dlrt: bool, adaptive: bool, tau: float, device,
                 chain_init: bool = False, dataset_name: str = "cifar10", load_model_path: Optional[str] = None,
                 load_weights: bool = True):
    """Select a model by name, wrap it in ``DLRTNetwork``, and move it to ``device``.

    Args:
        model_name: one of ``'vgg16'``, ``'lenet5'``, ``'alexnet'``. Ignored
            when ``load_model_path`` is given.
        baseline, tucker, adaptive, load_weights: forwarded unchanged to
            ``DLRTNetwork``.
        mat_dlrt: forwarded to ``DLRTNetwork`` as ``matrix_dlrt``.
        tau: value used for the ``'conv2d'`` entry of the ``tau`` dict passed
            to ``DLRTNetwork``. The ``'linear'`` entry is ``tau`` for lenet5
            and ``0.0`` for vgg16 and alexnet.
        device: passed to ``.to(device)`` (and as ``map_location`` when loading).
        chain_init: forwarded to ``DLRTNetwork`` for lenet5 and alexnet only.
        dataset_name: selects ``in_channels`` and ``out_classes`` for vgg16 and
            alexnet. Currently only ``"cifar10"`` is known.
        load_model_path: if truthy, a file containing a whole pickled model.
            It is loaded instead of building a new one.

    Returns:
        The model, after ``.to(device)`` has been called on it.

    Raises:
        ValueError: ``model_name`` needs dataset dimensions but
            ``dataset_name`` is not in the known dataset table.
        SystemExit: (code 1) ``model_name`` is not recognised.
    """
    if load_model_path:
        print("Loading model from file: " + load_model_path)
        f = _load_full_model(load_model_path, device)
        f.to(device)
        return f

    if model_name == 'lenet5':
        # Lenet5 takes no dataset-dependent sizes, so any dataset_name is accepted.
        f = DLRTNetwork(Lenet5(), adaptive=adaptive, tucker=tucker, matrix_dlrt=mat_dlrt,
                        tau={'linear': tau, 'conv2d': tau}, baseline=baseline, chain_init=chain_init,
                        load_weights=load_weights)
    elif model_name in ('vgg16', 'alexnet'):
        # Resolve the dataset before constructing anything, so an unknown
        # dataset fails with a clear message instead of an unbound local.
        try:
            in_channels, out_classes = _DATASET_SPECS[dataset_name]
        except KeyError:
            raise ValueError(
                f"Dataset '{dataset_name}' is not supported for model '{model_name}'; "
                f"known datasets: {sorted(_DATASET_SPECS)}"
            ) from None

        if model_name == 'vgg16':
            # NOTE(review): 32, 32, 256 are passed positionally to VGG. They
            # look like CIFAR-10 image size and a hidden width; confirm
            # against VGG's signature.
            f = DLRTNetwork(VGG(VGG_types['VGG16'], in_channels, 32, 32, 256, out_classes), adaptive=adaptive,
                            tucker=tucker,
                            matrix_dlrt=mat_dlrt, tau={'linear': 0.0, 'conv2d': tau}, dense_first_layer=True,
                            baseline=baseline, load_weights=load_weights)
        else:
            f = DLRTNetwork(AlexNet(output_dim=out_classes), adaptive=adaptive, tucker=tucker, matrix_dlrt=mat_dlrt,
                            tau={'linear': 0.0, 'conv2d': tau}, baseline=baseline, chain_init=chain_init,
                            load_weights=load_weights)
    else:
        print("Model not available. Exiting")
        sys.exit(1)

    f.to(device)
    return f
